import torch, argparse
from torchvision.datasets import CIFAR10
from torch.backends import cudnn
from torch.utils.data import DataLoader
import net, config


def _count_topk_correct(ranked, target, k):
    """Return how many samples have their true label among the first ``k`` ranked classes.

    Args:
        ranked: class indices sorted by descending score, one row per sample
            (as produced by ``torch.argsort(..., descending=True)``).
        target: true class index per sample; given a trailing dimension so it
            broadcasts against each row of ``ranked``.
        k: number of leading ranked classes to consider. Slicing clamps it, so a
            ``k`` larger than the number of classes considers every class.

    Returns:
        The hit count as a Python float, matching the float accumulators in ``eval``.
    """
    hits = ranked[:, :k] == target.unsqueeze(dim=-1)
    return hits.any(dim=-1).float().sum().item()


def eval(args):
    """Evaluate a SimCLR stage-2 classifier on the CIFAR-10 test split.

    Loads the weights at ``config.pre_model`` into ``net.SimCLRStage2`` and runs
    it over the CIFAR-10 test set, downloading the set into ``dataset/`` if
    needed. Prints top-1 / top-5 accuracy for every batch, then the totals over
    the whole split.

    Args:
        args: parsed command-line namespace; only ``args.batch_size`` is read.

    Note:
        The name shadows the builtin ``eval``. It is kept because it is the
        script's existing entry point.
    """
    if torch.cuda.is_available() and config.use_gpu:
        DEVICE = torch.device("cuda:" + str(config.gpu_name))
        # Let cuDNN auto-tune conv algorithms; pays off when input shapes stay
        # constant across batches.
        cudnn.benchmark = True
    else:
        DEVICE = torch.device("cpu")

    eval_dataset = CIFAR10(root="dataset", train=False, transform=config.test_transform, download=True)
    eval_data = DataLoader(eval_dataset, batch_size=args.batch_size, shuffle=False, num_workers=12)

    model = net.SimCLRStage2(num_class=len(eval_dataset.classes)).to(DEVICE)
    # strict=False: checkpoint keys that do not match the model are skipped silently.
    # NOTE(review): confirm config.pre_model is the stage-2 checkpoint. Otherwise the
    # mismatched layers keep their initial weights and no error is raised.
    model.load_state_dict(torch.load(config.pre_model, map_location="cpu"), strict=False)

    model.eval()
    with torch.no_grad():
        print("batch", " " * 1, "top1 acc", " " * 1, "top5 acc")
        total_correct_1, total_correct_5, total_num = 0.0, 0.0, 0.0
        for batch, (data, target) in enumerate(eval_data):
            data, target = data.to(DEVICE), target.to(DEVICE)
            pred = model(data)

            n_samples = data.size(0)
            total_num += n_samples
            # Class indices ordered from highest to lowest score, per sample.
            ranked = torch.argsort(pred, dim=-1, descending=True)
            correct_1 = _count_topk_correct(ranked, target, 1)
            correct_5 = _count_topk_correct(ranked, target, 5)
            total_correct_1 += correct_1
            total_correct_5 += correct_5

            print(
                "  {:02}  ".format(batch + 1),
                " {:02.3f}%  ".format(correct_1 / n_samples * 100),
                "{:02.3f}%  ".format(correct_5 / n_samples * 100)
            )
        print(
            "all eval dataset:",
            "top1 acc: {:02.3f}%".format(total_correct_1 / total_num * 100),
            "top5 acc:{:02.3f}%".format(total_correct_5 / total_num * 100)
        )


if __name__ == '__main__':
    # Command-line entry point: parse the batch size and run the evaluation.
    cli = argparse.ArgumentParser(description='test SimCLR')
    cli.add_argument('--batch_size', type=int, default=512, help='')
    eval(cli.parse_args())
